{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fe3fb743",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:03:18.835791Z",
     "start_time": "2023-10-18T14:03:17.555460Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cdcaae2",
   "metadata": {},
   "source": [
    "## Multinomial distribution"
   ]
  },
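  {
   "cell_type": "markdown",
   "id": "220add33",
   "metadata": {},
   "source": [
    "- A discrete probability distribution over the counts $x_1,\\ldots,x_k$ of $k$ possible outcomes in $n$ independent trials, where outcome $i$ has probability $p_i$\n",
    "    - PMF (probability mass function):\n",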
    "$$\n",
    "\\begin{align}\n",
    "f(x_1,\\ldots,x_k;n,p_1,\\ldots,p_k) & {} = \\Pr(X_1 = x_1 \\text{ and } \\dots \\text{ and } X_k = x_k) \\\\\n",
    "& {} = \\begin{cases} { \\displaystyle {n! \\over x_1!\\cdots x_k!}p_1^{x_1}\\times\\cdots\\times p_k^{x_k}}, \\quad &\n",
    "\\text{when } \\sum_{i=1}^k x_i=n \\\\  \\\\\n",
    "0 & \\text{otherwise,} \\end{cases}\n",
    "\\end{align}\n",
    "$$"
   ]
  },
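  {
   "cell_type": "markdown",
   "id": "b7e2a4c1",
   "metadata": {},
   "source": [
    "As a quick check of the PMF, here is a small worked example. The numbers are chosen purely for illustration and are not from the original notes: $n = 3$ trials, $k = 3$ outcomes with probabilities $(\\tfrac12, \\tfrac13, \\tfrac16)$, and observed counts $(1, 1, 1)$.\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "f(1,1,1;\\,3,\\tfrac12,\\tfrac13,\\tfrac16) & = {3! \\over 1!\\,1!\\,1!}\\left(\\tfrac12\\right)^{1}\\left(\\tfrac13\\right)^{1}\\left(\\tfrac16\\right)^{1} \\\\\n",
    "& = 6 \\times \\tfrac{1}{36} = \\tfrac16\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "The counts sum to $n = 3$, so the first branch of the case distinction applies; any count vector that does not sum to $n$ has probability 0."
   ]
  },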
  {
   "cell_type": "markdown",
   "id": "7d458c2c",
   "metadata": {},
   "source": [
    "## `np.random.multinomial`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3daad4d2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:03:42.705407Z",
     "start_time": "2023-10-18T14:03:42.693522Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2, 4, 0, 4, 5, 5],\n",
       "       [2, 6, 2, 6, 2, 2]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Roll a fair six-sided die 20 times, and repeat the experiment twice\n",
    "X = np.random.multinomial(20, [1/6] * 6, size=2)\n",
    "X"
   ]
  },
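  {
   "cell_type": "markdown",
   "id": "e514c600",
   "metadata": {},
   "source": [
    "- The first argument is `n` (the number of trials), the second is `p_1, p_2, ..., p_k` (the probability of each outcome), and the return value is the counts `x_1, x_2, ..., x_k`; with `size=2`, each row is one independent sample."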
   ]
  },
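  {
   "cell_type": "markdown",
   "id": "c4d9e1f0",
   "metadata": {},
   "source": [
    "To put the sampled counts in context, the mean and variance of each count follow from standard multinomial identities (these are not stated in the original notes). A small worked calculation, assuming the fair-die setting used here ($n = 20$, $p_i = 1/6$):\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "\\operatorname{E}[X_i] & = n p_i = 20 \\times \\tfrac16 \\approx 3.33 \\\\\n",
    "\\operatorname{Var}(X_i) & = n p_i (1 - p_i) = 20 \\times \\tfrac16 \\times \\tfrac56 \\approx 2.78\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "The standard deviation is therefore about $1.67$, so individual counts a few units away from $3.33$ are expected."
   ]
  },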
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5a69af61",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:03:51.145636Z",
     "start_time": "2023-10-18T14:03:51.137317Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([20, 20])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each row (one sample) sums to n = 20\n",
    "X.sum(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30ea0bfc",
   "metadata": {},
   "source": [
    "## `torch.multinomial`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca193ddc",
   "metadata": {},
   "source": [
    "- https://pytorch.org/docs/stable/generated/torch.multinomial.html"
   ]
  },
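  {
   "cell_type": "markdown",
   "id": "c10f4e81",
   "metadata": {},
   "source": [
    "- The input is a tensor of weights (not necessarily a probability distribution that sums to 1); the output is the indices sampled according to those weights.\n",
    "    - The weights must be non-negative, finite and have a non-zero sum.\n",
    "    - Internally, the weights are still normalized.\n",
    "- `replacement`: defaults to `False`, i.e. sampling without replacement, so each index is drawn at most once."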
   ]
  },
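  {
   "cell_type": "markdown",
   "id": "d2f6a8b3",
   "metadata": {},
   "source": [
    "The normalization mentioned above can be written out explicitly. As a sketch, using the weights from the cells that follow, `[0, 10, 3, 0]`, the effective probability of drawing each index on a single draw is:\n",
    "\n",
    "$$\n",
    "p_i = \\frac{w_i}{\\sum_{j} w_j}, \\qquad w = (0, 10, 3, 0) \\;\\Rightarrow\\; p = \\left(0, \\tfrac{10}{13}, \\tfrac{3}{13}, 0\\right) \\approx (0, 0.7692, 0.2308, 0)\n",
    "$$\n",
    "\n",
    "Here $w_i$ and $p_i$ are notation introduced for this illustration. Indices 0 and 3 have zero weight, so a single draw should only ever return index 1 or 2."
   ]
  },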
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0aa6c423",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:30:09.266161Z",
     "start_time": "2023-10-18T14:30:09.260414Z"
    }
   },
   "outputs": [],
   "source": [
    "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c0f3c323",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:52:58.026357Z",
     "start_time": "2023-10-18T14:52:58.014991Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 1])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.multinomial(weights, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "393770da",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:53:41.312288Z",
     "start_time": "2023-10-18T14:53:41.301568Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 1, 0, 3])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.multinomial(weights, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "1dd348b9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:54:06.912736Z",
     "start_time": "2023-10-18T14:54:06.901721Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 1, 1, 1])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.multinomial(weights, 4, replacement=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "487724f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:40:41.183285Z",
     "start_time": "2023-10-18T14:40:09.179000Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0000\n"
     ]
    }
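   ],
   "source": [
    "torch.manual_seed(0)\n",
    "# Sanity check: indices with zero probability should never be sampled\n",
    "probs = torch.tensor([1.0] + [0.0] * 999, dtype=torch.double)\n",
    "n_trials = 1_000_000\n",
    "wrongs = 0\n",
    "for _ in range(n_trials):\n",
    "    sampled = torch.multinomial(probs, num_samples=1)\n",
    "    if sampled.item() != 0:\n",
    "        wrongs += 1\n",
    "\n",
    "# Percentage of draws that landed on a zero-probability index\n",
    "print(f\"{100 * wrongs / n_trials:.4f}\")"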
   ]
  },
  {
   "cell_type": "markdown",
   "id": "864aa960",
   "metadata": {},
   "source": [
    "## `torch.distributions.categorical.Categorical`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f93b57d8",
   "metadata": {},
   "source": [
    "- https://pytorch.org/docs/stable/distributions.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a86b0f09",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:29:31.957540Z",
     "start_time": "2023-10-18T14:29:31.950887Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.distributions import Categorical\n",
    "from collections import Counter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8df8ca92",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:30:32.040248Z",
     "start_time": "2023-10-18T14:30:32.032335Z"
    }
   },
   "outputs": [],
   "source": [
    "# Unnormalized weights are accepted: `probs` is normalized to sum to 1\n",
    "c = Categorical(probs=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b86708bf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:30:34.207727Z",
     "start_time": "2023-10-18T14:30:34.174885Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.7692, 0.2308, 0.0000])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c.probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "9fa91ad0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-18T14:32:59.330044Z",
     "start_time": "2023-10-18T14:32:59.316032Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor(1),\n",
       " tensor(2),\n",
       " tensor(1),\n",
       " tensor(1),\n",
       " tensor(1),\n",
       " tensor(1),\n",
       " tensor(2),\n",
       " tensor(1),\n",
       " tensor(1),\n",
       " tensor(2)]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[c.sample() for _ in range(10)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9dab36bc-c651-4b8e-afbc-e8abe9d2b9c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counter({2: 137, 3: 128, 1: 119, 0: 116})\n",
      "Counter({1: 134, 3: 128, 2: 125, 0: 113})\n"
     ]
    }
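   ],
   "source": [
    "# Sampling from a categorical distribution\n",
    "counts = Counter()  # separate name so the Categorical `c` above is not shadowed\n",
    "m = Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]))\n",
    "# Draw 500 samples one at a time\n",
    "for _ in range(500):\n",
    "    counts.update([m.sample().item()])\n",
    "print(counts)\n",
    "# Draw 500 samples in a single batched call\n",
    "print(Counter(m.sample((500,)).tolist()))"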
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
